// SPDX-License-Identifier: MPL-2.0

//! Manages the kernel heap using slab or buddy allocation strategies.

use core::{
    alloc::{AllocError, GlobalAlloc, Layout},
    ptr::NonNull,
};

use crate::mm::Vaddr;

mod slab;
mod slot;
mod slot_list;

pub use self::{
    slab::{Slab, SlabMeta},
    slot::{HeapSlot, SlotInfo},
    slot_list::SlabSlotList,
};

/// The trait for the global heap allocator.
///
/// By providing the slab ([`Slab`]) and heap slot ([`HeapSlot`])
/// mechanisms, OSTD allows users to implement their own kernel heap in a safe
/// manner, as an alternative to the unsafe [`core::alloc::GlobalAlloc`].
///
/// To provide the global heap allocator, use [`crate::global_heap_allocator`]
/// to mark a static variable that implements this trait. In addition, use
/// [`crate::global_heap_allocator_slot_map`] to specify the sizes of
/// slots for different layouts. The requirement to provide this slot map
/// may be lifted in the future.
///
/// Implementations must be [`Sync`] because OSTD reaches the allocator
/// through a shared `&'static` reference.
pub trait GlobalHeapAllocator: Sync {
    /// Allocates a [`HeapSlot`] according to the layout.
    ///
    /// OSTD calls this method to allocate memory from the global heap.
    ///
    /// The returned [`HeapSlot`] must be valid for the layout, i.e., the size
    /// must be at least the size of the layout and the alignment must be at
    /// least the alignment of the layout. Furthermore, the size of the
    /// returned [`HeapSlot`] must match the size returned by the function
    /// marked with [`crate::global_heap_allocator_slot_map`].
    ///
    /// If a returned slot violates any of these requirements, OSTD aborts.
    /// Returning `Err` makes the allocation fail: OSTD hands a null pointer
    /// back to the caller of the Rust global allocator.
    fn alloc(&self, layout: Layout) -> Result<HeapSlot, AllocError>;

    /// Deallocates a [`HeapSlot`].
    ///
    /// OSTD calls this method to deallocate memory back to the global heap.
    ///
    /// Each deallocation must correspond to exactly one previous allocation. The provided
    /// [`HeapSlot`] must match the one returned from the original allocation.
    ///
    /// Returning `Err` is treated as fatal: OSTD logs the failure and aborts.
    fn dealloc(&self, slot: HeapSlot) -> Result<(), AllocError>;
}

// Up-calls into the OSTD user's code. Both items are only declared here; their
// definitions come from code generated by the user-facing attribute macros, and
// the symbols are resolved at link time.
extern "Rust" {
    /// The reference to the global heap allocator generated by the
    /// [`crate::global_heap_allocator`] attribute.
    static __GLOBAL_HEAP_ALLOCATOR_REF: &'static dyn GlobalHeapAllocator;

    /// Gets the size and type of heap slots to serve allocations of the layout.
    /// See [`crate::global_heap_allocator_slot_map`].
    ///
    /// Returns `None` when the user's slot map has no entry for `layout`.
    fn __GLOBAL_HEAP_SLOT_INFO_FROM_LAYOUT(layout: Layout) -> Option<SlotInfo>;
}

/// Returns the global heap allocator registered by the OSTD user.
///
/// The underlying `&'static` reference is emitted by the
/// [`crate::global_heap_allocator`] attribute; this wrapper confines the
/// `unsafe` access to the extern static to a single place.
fn get_global_heap_allocator() -> &'static dyn GlobalHeapAllocator {
    // SAFETY: This up-call is redirected safely to Rust code by OSDK.
    unsafe {
        // Only the `&'static` reference is copied out of the extern static.
        __GLOBAL_HEAP_ALLOCATOR_REF
    }
}

/// Gets the size and type of heap slots to serve allocations of the layout.
///
/// This function is defined by the OSTD user. We require it to be implemented
/// as a `const fn`, which makes it deterministic: the same `layout` always
/// yields the same result. The deallocation path relies on this to rebuild
/// exactly the [`SlotInfo`] that was used when the memory was allocated.
///
/// Returns `None` when the user's slot map has no entry for `layout`; the
/// allocation and deallocation paths in this module abort in that case.
///
/// See [`crate::global_heap_allocator_slot_map`].
fn slot_size_from_layout(layout: Layout) -> Option<SlotInfo> {
    // SAFETY: This up-call is redirected safely to Rust code by OSDK.
    unsafe { __GLOBAL_HEAP_SLOT_INFO_FROM_LAYOUT(layout) }
}

/// Logs an error message (taking `format!`-style arguments) and aborts.
///
/// The expansion ends with a call to `panic::abort`, so the macro can be used
/// in diverging positions such as `let ... else` blocks and `-> !` functions.
macro_rules! abort_with_message {
    ($($fmt_arg:tt)*) => {
        log::error!($($fmt_arg)*);
        $crate::panic::abort();
    };
}

#[alloc_error_handler]
fn handle_alloc_error(layout: core::alloc::Layout) -> ! {
    abort_with_message!("Heap allocation error, layout = {:#x?}", layout);
}

/// Zero-sized adapter that forwards Rust's [`GlobalAlloc`] requests to the
/// user-provided [`GlobalHeapAllocator`].
struct AllocDispatch;

/// The Rust global allocator, backed by [`AllocDispatch`].
#[global_allocator]
static HEAP_ALLOCATOR: AllocDispatch = AllocDispatch;

// TODO: Somehow restrict unwinding in the user-provided global allocator.
// Panicking should be fine, but we shouldn't unwind on panics.
unsafe impl GlobalAlloc for AllocDispatch {
    /// Allocates memory by forwarding `layout` to the user-provided
    /// [`GlobalHeapAllocator`].
    ///
    /// Returns a null pointer if the user allocator reports [`AllocError`].
    /// Aborts if the slot map has no entry for `layout`. Also aborts if the
    /// returned slot's size disagrees with the slot map, is smaller than
    /// `layout.size()`, or its address is not aligned to `layout.align()`.
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        // Look up the expected slot first so that an unmapped layout aborts
        // before the user allocator is consulted.
        let Some(required_slot) = slot_size_from_layout(layout) else {
            abort_with_message!("Heap allocation size not found for layout = {:#x?}", layout);
        };

        let res = get_global_heap_allocator().alloc(layout);
        let Ok(slot) = res else {
            // A null return is how `GlobalAlloc` signals failure to its caller.
            return core::ptr::null_mut();
        };

        // `dealloc` rebuilds the slot from the slot map alone, so the slot's
        // size must agree with it. The slot must also be large enough and
        // suitably aligned for `layout`.
        if required_slot.size() != slot.size()
            || slot.size() < layout.size()
            || slot.as_ptr() as Vaddr % layout.align() != 0
        {
            abort_with_message!(
                "Heap allocation mismatch: slot ptr = {:p}, size = {:x}; layout = {:#x?}; required_slot = {:#x?}",
                slot.as_ptr(),
                slot.size(),
                layout,
                required_slot,
            );
        }

        slot.as_ptr()
    }

    /// Returns memory to the user-provided [`GlobalHeapAllocator`].
    ///
    /// The [`HeapSlot`] is reconstructed from `ptr` and the slot info looked
    /// up for `layout`. Aborts if the lookup fails or if the user allocator
    /// reports an error.
    unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
        // Now we restore the `HeapSlot` from the pointer and the layout.
        let Some(required_slot) = slot_size_from_layout(layout) else {
            abort_with_message!(
                "Heap deallocation size not found for layout = {:#x?}",
                layout
            );
        };

        // SAFETY: The caller guarantees that `ptr` is a live block returned by
        // `alloc` for this same `layout`. That makes `ptr` non-null, since
        // `alloc` returns null only on failure. The user's slot map is a
        // `const fn` and therefore deterministic, so `required_slot` equals the
        // slot info that `alloc` looked up. `alloc` aborts unless the returned
        // slot's size matched that info.
        let slot = unsafe { HeapSlot::new(NonNull::new_unchecked(ptr), required_slot) };
        let res = get_global_heap_allocator().dealloc(slot);

        if res.is_err() {
            abort_with_message!(
                "Heap deallocation error, ptr = {:p}, layout = {:#x?}, required_slot = {:#x?}",
                ptr,
                layout,
                required_slot,
            );
        }
    }
}
